################################################################################
#       Copyright (C) 2010 Michael Yurko <myurko@gmail.com>
#
#This file is part of pyQAP.
#
#pyQAP is free software: you can redistribute it and/or modify
#it under the terms of the GNU General Public License as published by
#the Free Software Foundation, either version 2 of the License, or
#(at your option) any later version.
#
#pyQAP is distributed in the hope that it will be useful,
#but WITHOUT ANY WARRANTY; without even the implied warranty of
#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#GNU General Public License for more details.
#
#You should have received a copy of the GNU General Public License
#along with pyQAP.  If not, see <http://www.gnu.org/licenses/>.
################################################################################

import time
import classes
import vars
import heuristics
import bounds
import lap
import multiprocessing as mp
from collections import deque
from math import factorial
import os
import sys
'''
Contains code for the main backtracking QAP solver
and associated bound and heuristic setters.
'''

def solve(problem, bound=bounds.none,
          heuristic=heuristics.first, verbose = vars.verbose):
    """
    A function to solve a quadratic assignment problem.
    
    Returns the optimal solution to the given Problem using depth-first
    backtracking. The heuristic supplies the initial incumbent solution; a
    partial solution is only expanded while ``bound(psol, problem)`` is
    strictly below the best cost found so far, so a tighter lower bound and a
    better heuristic both help prune the search tree.
    
    :param problem: The QAP Problem to solve.
    :type problem: Problem
    :param bound: The lower bound to use
    :type bound: A function
    :param heuristic: The initial heuristic to use
    :type heuristic: A function
    :param verbose: Whether to print intermediate results. The default is the
    value of vars.verbose at the time this module was imported.
    :returns: The Solution to the given problem.
    :rtype: Solution
    
    Examples:
    
    >>> from pyQAP import *
    >>> prob = problems.QAPLIB('nug12')
    >>> prob = problems.reduced(prob,6)
    >>> solve(prob, bounds.none, verbose = False)
    QAP Solution: [...] Cost: 94.000000 Nodes visited: 1958.000000 Total time: ...
    >>> prob = problems.QAPLIB('jen4')
    >>> solve(prob, bounds.none, verbose = False)
    QAP Solution: [... Cost: 3260.000000 Nodes visited: 66.000000 Total time: ...
    """ 
    n = problem.n
    nodes = 1
    
    #the heuristic's solution is the initial incumbent
    first_sol = heuristic(problem)
    best_yet = first_sol.cost(problem)
    best_sol = first_sol.sol
    
    #preallocated PSolutions reused as the DFS stack; expanding a node at
    #depth k grows the stack by at most n-k-1 entries, so n(n+1)/2 slots
    #are always enough
    stack = [classes.PSolution(n, problem) for _ in xrange((n*n+n)//2)]
    top = 0
    
    #start timer (note: the heuristic above is not included)
    start = time.clock()

    #start backtracking
    while top >= 0:
        nodes += 1
        psol = stack[top]
        top -= 1
        if psol.placed == n:
            #complete assignment: compare against the incumbent
            c = psol.cost(problem)
            if c < best_yet:
                best_yet = c
                best_sol = psol.sol.copy()
                if verbose:
                    print(str(best_sol) + " is better.")
                    print(str(best_yet) + " is the new cost.")
        elif bound(psol, problem) < best_yet:
            #not fathomed: psol stays on the stack at its old slot
            top += 1
            #find the first unused facility; psol itself will take it
            for i in xrange(n):
                if psol.unplaced[i]:
                    first = i
                    break
            #every other unused facility becomes a new child node above psol
            for i in xrange(first+1, n):
                if psol.unplaced[i]:
                    top += 1
                    psol.change_new(i, stack[top])
            #reuse psol in place as the child that places `first`
            psol.sol[psol.placed] = first
            psol.unplaced[first] = False
            psol.placed += 1
    
    return classes.Solution(best_sol, best_yet, nodes, time.clock()-start)
    
def search_sub_tree(root):
    '''
    A helper function for parallel_solve().
    
    This function takes the given root and explores it using the same
    depth-first backtracking procedure as solve(). It relies on two
    module-level globals that parallel_solve()'s pool initializer sets up in
    each worker process: ``ctx`` (problem, bound, poll interval, verbosity
    and the shared best-solution state) and ``stack`` (this worker's list of
    preallocated PSolution objects, reused as the search stack). It is
    unlikely that this is useful outside of parallel_solve().
    
    The shared best cost is re-read every ``ctx.interval`` iterations so that
    improvements found by other workers tighten this worker's pruning.
    
    :param root: the root of the part of the tree to explore
    :type root: PSolution
    :returns: None; the number of nodes explored is added to ``ctx.nodes``.
    
    Examples:
    
    Isn't used outside of parallel_solve().
    '''
    global ctx
    global stack

    #local aliases for values used in the hot loop
    prob = ctx.problem
    lb = ctx.bound
    interval = ctx.interval
    shared_best = ctx.best_yet
    n = prob.n

    nodes = 0
    best = shared_best.value
    since_poll = 0

    stack[0] = root
    top = 0

    #start backtracking
    while top >= 0:
        since_poll += 1
        if since_poll >= interval:
            #pick up improvements made by other workers
            best = shared_best.value
            since_poll = 0
        psol = stack[top]
        top -= 1
        nodes += 1
        if psol.placed == n:
            c = psol.cost(prob)
            #cheap unlocked pre-check; confirmed again under the lock
            if c < best and c < shared_best.value:
                with ctx.best_write_lock:
                    #another worker may have stored a better cost between the
                    #check above and acquiring the lock; never overwrite it
                    improved = c < shared_best.value
                    if improved:
                        shared_best.value = c
                        for i in xrange(n):
                            ctx.best_sol[i] = psol.sol[i]
                    best = shared_best.value
                if improved and ctx.verbose:
                    print(str(psol.sol) + " is better.")
                    print(str(best) + " is the new cost.")
        elif lb(psol, prob) < best:
            #not fathomed: psol stays on the stack at its old slot
            top += 1
            #find the first unused facility; psol itself will take it
            for i in xrange(n):
                if psol.unplaced[i]:
                    first = i
                    break
            #every other unused facility becomes a new child node above psol
            for i in xrange(first+1, n):
                if psol.unplaced[i]:
                    top += 1
                    psol.change_new(i, stack[top])
            #reuse psol in place as the child that places `first`
            psol.sol[psol.placed] = first
            psol.unplaced[first] = False
            psol.placed += 1

    #`+=` on a shared Value is a separate read and write; hold its lock so
    #concurrent workers do not lose each other's counts
    with ctx.nodes.get_lock():
        ctx.nodes.value += nodes
    
def parallel_solve(problem, threads, bound=bounds.none,
                   heuristic=heuristics.first,
                   verbose=vars.verbose, poll_interval=100, 
                   pieces = lambda t: t*t*t):
    '''
    Solves a given QAP in parallel.
    
    This function solves the given Problem in parallel using the given number
    of worker processes. It first runs the heuristic to get an initial
    incumbent, then splits the backtracking tree breadth-first until more than
    ``pieces(threads)`` nodes are queued, and hands each queued node to
    search_sub_tree() in a multiprocessing pool. If it is impossible to create
    the minimum number of nodes with the given problem (since it is too small)
    this falls back to the single-threaded version, solve(), with the same
    bound, heuristic and verbosity. As in solve(), a node is only kept while
    its lower bound is strictly below the best cost found so far, both during
    splitting and inside the workers.
    
    :param problem: the QAP to solve
    :type problem: Problem
    :param threads: the number of worker processes to use
    :type threads: int
    :param bound: The lower bound to use
    :type bound: A function
    :param heuristic: The initial heuristic to use
    :type heuristic: A function
    :param verbose: whether or not to print intermediate solutions
    :type verbose: boolean
    :param poll_interval: the number of iterations before the helper processes
    synchronize with the global data.
    :type poll_interval: int
    :param pieces: takes the number of threads and returns the minimum number of
    nodes to initially create
    :type pieces: function
    :returns: The Solution to the given problem.
    :rtype: Solution
    
    Examples:
    
    >>> import problems as p
    >>> import bounds as b
    >>> prob = p.QAPLIB('nug12')
    >>> prob = p.reduced(prob,6)
    >>> parallel_solve(prob, 2, bound=b.gilmore_lawler, verbose = False)
    QAP Solution: ... Cost: 94.000000 Nodes visited: ... Total time: ...
    >>> prob = p.QAPLIB('jen4')
    >>> parallel_solve(prob, 2, bound=b.gilmore_lawler, verbose = False)
    QAP Solution: ... Cost: 3260.000000 Nodes visited: ... Total time: ...
    '''
    global ctx

    #start timer (wall clock, since the work is spread over processes)
    start = time.time()
    
    #check if problem is large enough, use single thread otherwise
    chunks = pieces(threads)
    if factorial(problem.n) <= chunks:
        return solve(problem, bound=bound, heuristic=heuristic,
                     verbose=verbose)
    
    n = problem.n

    #heuristic incumbent first, so it can already prune while splitting
    first_sol = heuristic(problem)
    best_cost = first_sol.cost(problem)
    best_perm = list(first_sol.sol)
    
    #generate nodes to feed to sub-solvers (breadth-first)
    visited = 0
    q = deque([classes.PSolution(n, problem)])
    while q and len(q) <= chunks:
        current = q.popleft()
        visited += 1
        if current.placed == n:
            #pruning can let splitting reach complete assignments; they have
            #no children to hand out, so score them here
            c = current.cost(problem)
            if c < best_cost:
                best_cost = c
                best_perm = list(current.sol)
                if verbose:
                    print(str(current.sol) + " is better.")
                    print(str(best_cost) + " is the new cost.")
            continue
        for i in xrange(n):
            if current.unplaced[i]:
                child = current.add_new(i)
                #same fathoming test as solve()
                if bound(child, problem) < best_cost:
                    q.append(child)

    if not q:
        #every subtree was fathomed or solved while splitting
        return classes.Solution(best_perm, best_cost, visited,
                                time.time()-start)

    #multiprocessing setup
    ctx = classes.ParallelContext()
    def init_pool(best_yet, best_sol, best_write_lock, problem,
                  interval, verbose, nodes, bound):
        #runs once in every worker: publish shared state through the
        #module-level ctx and build this worker's private DFS stack
        ctx.best_yet = best_yet
        ctx.best_sol = best_sol
        ctx.best_write_lock = best_write_lock
        ctx.problem = problem
        ctx.interval = interval
        ctx.verbose = verbose
        ctx.nodes = nodes
        ctx.bound = bound
        os.nice(vars.niceness)
        global stack
        stack = [classes.PSolution(n, problem=ctx.problem)
                 for _ in xrange((n*n+n)//2)]

    #shared incumbent; workers only write it while holding best_write_lock
    best_yet = mp.Value('d', lock = False)
    best_yet.value = best_cost
    best_sol = mp.Array('i', n, lock = False)
    for i in xrange(n):
        best_sol[i] = best_perm[i]
    best_write_lock = mp.RLock()
    #node counter (synchronized Value); workers add to it under its lock
    nodes = mp.Value('i')
    nodes.value = visited
    pool = mp.Pool(threads, init_pool, (best_yet, best_sol, best_write_lock,
                                        problem, poll_interval, verbose, nodes,
                                        bound))
    try:
        #get() (unlike wait()) re-raises any exception raised in a worker
        pool.map_async(search_sub_tree, q).get()
    except BaseException:
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
    
    return classes.Solution(list(best_sol), best_yet.value, nodes.value,
                            time.time()-start)